# Copyright 2020 Makani Technologies LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A library for interacting with the AIO messages and nodes.

This module provides an interface for python scripts to interact with the
AIO network on the M600.  It relies on H2PY generated files to understand
message struct definitions, AIO node enums, message type enums, etc.
"""

from __future__ import absolute_import
from __future__ import print_function
import collections
import socket
import struct
import time

from makani.avionics.common import aio_version
from makani.avionics.common import network_config
from makani.avionics.common import pack_aio_header as aio_header
from makani.avionics.common import pack_avionics_messages
from makani.avionics.network import aio_node
from makani.avionics.network import aio_node_to_ip_address
from makani.avionics.network import message_type as aio_message_type
from makani.control import pack_control_telemetry
from makani.control import pack_ground_telemetry
from makani.lib.python import c_helpers
from makani.sim import pack_sim_messages
from makani.sim import pack_sim_telemetry


class InvalidAioNodeException(Exception):
  """Raised to signal that an invalid AIO node was supplied."""


# Helpers for the AioNode and MessageType enums.  Within this module they are
# used, through Value(), to turn enum names (e.g. 'kAioNodeMotorSbo' or
# 'kMessageTypeMotorStatus') into the corresponding enum values.
aio_node_helper = c_helpers.EnumHelper('AioNode', aio_node)
message_type_helper = c_helpers.EnumHelper('MessageType', aio_message_type)


def AioNodeToIpAddressString(node):
  """Format the IP address of an AIO node as a dotted-quad string.

  Args:
    node: AIO node number.

  Returns:
    A string of the form 'a.b.c.d' built from the a, b, c and d fields of the
    address returned by aio_node_to_ip_address.AioNodeToIpAddress().
  """
  address = aio_node_to_ip_address.AioNodeToIpAddress(node)
  octets = (address.a, address.b, address.c, address.d)
  return '.'.join('%d' % octet for octet in octets)


def AioMessageTypeToIpAddressString(message_type):
  """Format the multicast IP address of a message type as a string.

  Args:
    message_type: Message type to look up.

  Returns:
    A string of the form 'a.b.c.d' built from the a, b, c and d fields of the
    address returned by network_config.AioMessageTypeToIpAddress().
  """
  address = network_config.AioMessageTypeToIpAddress(message_type)
  octets = (address.a, address.b, address.c, address.d)
  return '.'.join('%d' % octet for octet in octets)


class AioClientException(Exception):
  """Error raised by AioClient and the AIO helper functions in this module."""


def GetAioMessageStruct(message_type_name):
  """Look up the message class that corresponds to a MessageType enum name.

  Most message types follow a naming convention: the 'kMessageType' prefix is
  stripped and 'Message' is appended, e.g.
      kMessageTypeMotorStatus -> MotorStatusMessage
  The telemetry message types are exceptions and are mapped explicitly.

  Args:
    message_type_name: Name of the MessageType enum, e.g.
        'kMessageTypeMotorStatus'.

  Returns:
    Message ctype struct.

  Raises:
    AioClientException: Raised if the message type is invalid, i.e. the
        selected pack module has no attribute with the expected struct name.
  """
  # Message types whose struct is found in pack_sim_messages rather than in
  # pack_avionics_messages.
  sim_message_types = ('kMessageTypeDynamicsReplay',
                       'kMessageTypeEstimatorReplay',
                       'kMessageTypeSimCommand',
                       'kMessageTypeSimSensor',
                       'kMessageTypeSimTetherDown')
  try:
    # Message types whose struct name does not follow the naming convention.
    irregular = {
        'kMessageTypeControlTelemetry':
            (pack_control_telemetry, 'ControlTelemetry'),
        'kMessageTypeControlSlowTelemetry':
            (pack_control_telemetry, 'ControlSlowTelemetry'),
        'kMessageTypeControlDebug':
            (pack_control_telemetry, 'ControlDebugMessage'),
        'kMessageTypeSimTelemetry':
            (pack_sim_telemetry, 'SimTelemetry'),
        'kMessageTypeGroundTelemetry':
            (pack_ground_telemetry, 'GroundTelemetry'),
    }
    if message_type_name in irregular:
      module, struct_name = irregular[message_type_name]
    else:
      if message_type_name in sim_message_types:
        module = pack_sim_messages
      else:
        module = pack_avionics_messages
      struct_name = message_type_name[len('kMessageType'):] + 'Message'
    return getattr(module, struct_name)
  except AttributeError:
    raise AioClientException(
        'No struct for AIO message type: ' + message_type_name)


def GetPackMessageSize(message_type_name):
  """Get the size of the packed messages in bytes.

  The size is read from the PACK_<NAME>_SIZE constant of the pack module that
  also provides the message struct (see GetAioMessageStruct), so the two
  functions agree on which module owns each message type.

  Args:
    message_type_name: Name of the MessageType enum, e.g.
        'kMessageTypeMotorStatus'.

  Returns:
    The value of the PACK_<NAME>_SIZE constant for the message type.

  Raises:
    AttributeError: Raised if the selected pack module has no size constant
        for the message type.
  """
  # Telemetry messages have irregular struct names and hence irregular
  # constant names.
  if message_type_name == 'kMessageTypeControlTelemetry':
    return pack_control_telemetry.PACK_CONTROLTELEMETRY_SIZE
  elif message_type_name == 'kMessageTypeControlSlowTelemetry':
    return pack_control_telemetry.PACK_CONTROLSLOWTELEMETRY_SIZE
  elif message_type_name == 'kMessageTypeControlDebug':
    return pack_control_telemetry.PACK_CONTROLDEBUGMESSAGE_SIZE
  elif message_type_name == 'kMessageTypeSimTelemetry':
    return pack_sim_telemetry.PACK_SIMTELEMETRY_SIZE
  elif message_type_name == 'kMessageTypeGroundTelemetry':
    return pack_ground_telemetry.PACK_GROUNDTELEMETRY_SIZE

  # Same set of simulator message types that GetAioMessageStruct resolves in
  # pack_sim_messages.  The two replay types were previously missing here and
  # were looked up in pack_avionics_messages instead.
  if message_type_name in ('kMessageTypeDynamicsReplay',
                           'kMessageTypeEstimatorReplay',
                           'kMessageTypeSimCommand',
                           'kMessageTypeSimSensor',
                           'kMessageTypeSimTetherDown'):
    module = pack_sim_messages
  else:
    module = pack_avionics_messages

  # kMessageTypeFooBar -> PACK_FOOBARMESSAGE_SIZE.
  short_name = message_type_name[len('kMessageType'):]
  return getattr(module, 'PACK_%s_SIZE' % (short_name.upper() + 'MESSAGE'))


def GetMessageStructsMulticastGroups(message_types):
  """Build per-message-type lookups of message classes and multicast groups.

  Args:
    message_types: A list of message type names or short names, as accepted
        by message_type_helper.Value().  Note that the struct lookup uses the
        name exactly as given, so any name for which GetAioMessageStruct()
        raises AioClientException is mapped to None.

  Returns:
    message_structs: A dict of AIO message classes (or None when no struct
        was found) indexed by message type enum.
    mcast_groups: A dict of multicast group IP address strings indexed by
        message type enum.
  """

  def _StructOrNone(type_name):
    """Return the message class for type_name, or None if there is none."""
    try:
      return GetAioMessageStruct(type_name)
    except AioClientException:
      return None

  structs_by_enum = {}
  groups_by_enum = {}
  for type_name in message_types:
    struct_cls = _StructOrNone(type_name)
    type_enum = message_type_helper.Value(type_name)
    structs_by_enum[type_enum] = struct_cls
    groups_by_enum[type_enum] = AioMessageTypeToIpAddressString(type_enum)
  return structs_by_enum, groups_by_enum


class AioClient(object):
  """Vehicle for AIO interaction.

  All sending and receiving of AIO messages is done through this class.  An
  instance owns a single UDP socket that is bound to the AIO port and joined
  to the multicast group of every requested message type.  Call Close() when
  done, or use the instance in a `with` statement to have the socket closed
  automatically.
  """

  def __init__(self, message_types, allowed_sources=None, timeout=None):
    """Build AioClient instance.

    Args:
      message_types: A list of message type names to receive and send, e.g.
          ['kMessageTypeMotorStatus', 'kMessageTypeMotorStacking'].
      allowed_sources: A list of allowed AIO source nodes, e.g.
          ['kAioNodeMotorSbo', 'kAioNodeMotorSto'].  The list is used to
          filter calls to AioClient.Recv().  A value of None (or an empty
          list) allows all sources.
      timeout: A number specifying the timeout for calls to AioClient.Recv() in
          seconds.  A value of None (or 0) specifies no timeout.

    Returns:
      An instance of AioClient.

    Raises:
      AioClientException: No message types were specified.
      socket.error: The socket could not be configured, bound or joined to a
          multicast group.  The socket is closed before the error propagates.
    """
    # Default sequence number of 0 so that rebooting AIO nodes, which
    # default to 1, aren't ignored.
    self._send_seq_nums = collections.defaultdict(int)
    self._recv_seq_nums = collections.defaultdict(int)
    self._recv_times = collections.defaultdict(int)

    if not message_types:
      raise AioClientException('No message_types specified.')
    # Find message types.
    self._message_structs, self._mcast_groups = (
        GetMessageStructsMulticastGroups(message_types))

    # Find allowed AIO source nodes.  Default to allowing all AIO nodes.
    if allowed_sources:
      self._allowed_sources = set(
          aio_node_helper.Value(source) for source in allowed_sources)
    else:
      self._allowed_sources = None

    # A falsy timeout (None or 0) leaves the socket in blocking mode.
    self._timeout = timeout if timeout else None

    self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
      # Bind to multi-cast groups.
      self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
      self._sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL, 20)
      self._sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1)

      self._sock.bind(('', network_config.UDP_PORT_AIO))

      try:
        for group in self._mcast_groups.values():
          self._sock.setsockopt(
              socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, struct.pack(
                  '4sl', socket.inet_pton(socket.AF_INET, group),
                  socket.INADDR_ANY))
      except socket.error:
        print ('Unable to add multicast group.  Linux is limited to 20 multicast '
               'groups by default.  You may need to increase your limit:\n'
               '  echo <new_limit> > /proc/sys/net/ipv4/igmp_max_memberships\n')
        raise

      if self._timeout:
        self._sock.settimeout(self._timeout)
    except Exception:
      # Construction failed part-way; the caller never receives an instance
      # it could Close(), so release the socket here instead of leaking it.
      self._sock.close()
      raise

  def __enter__(self):
    """Return self so that the client can be used in a `with` statement."""
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Close the socket when leaving a `with` block.

    Returns:
      False, so that an exception raised inside the block is not suppressed.
    """
    self.Close()
    return False

  def _BuildHeader(self, source, message_type):
    """Build AIO header.

    Args:
      source: AioNode enum value to place in the header's source field.
      message_type: MessageType enum value to place in the header's type field.

    Returns:
      An aio_header.AioHeader stamped with the current AIO version, the next
      sequence number for (source, message_type) and the current time in
      microseconds.
    """
    msg_key = (source, message_type)

    header = aio_header.AioHeader()
    header.version = aio_version.AIO_VERSION
    header.source = source
    header.type = message_type
    header.sequence = self._send_seq_nums[msg_key]
    timestamp_type = c_helpers.GetFieldType(aio_header.AioHeader, 'timestamp')
    header.timestamp = timestamp_type(int(time.time() * 1e6))

    # Sequence numbers are tracked per (source, type) and wrap at 2**16.
    self._send_seq_nums[msg_key] = (self._send_seq_nums[msg_key] + 1) % 2**16

    return header

  def IsDuplicate(self, header, payload_string, cur_time):  # pylint: disable=unused-argument
    """Determine if AIO message is duplicate based on header.

    Uses past received sequence number and timestamp for a given AIO source
    and message type to determine if the message is a duplicate.  Can be
    overloaded to create custom de-duplication logic.

    Args:
      header: An aio_header.AioHeader for the message.
      payload_string: A string containing the big-endian message data.  Not
          used here.  Only present for overriding subclasses.
      cur_time: Time of received message, in seconds.

    Returns:
      A boolean indicating if the received message is a duplicate.
    """
    last_seq = self._recv_seq_nums[(header.source, header.type)]
    last_time = self._recv_times[(header.source, header.type)]
    cur_seq = header.sequence

    # Sequence numbers expire after maximum latency.
    if cur_time - last_time < aio_header.AIO_EXPIRATION_TIME_US * 1e-6:
      # Expected duplication.
      if cur_seq == last_seq:
        return True
      # Out of order.  The difference is taken modulo 2**16 so that sequence
      # number wrap-around is handled.
      if (cur_seq - last_seq) % 2**16 > aio_header.AIO_ACCEPTANCE_WINDOW:
        return True
    return False

  def Recv(self, accept_invalid=False):
    """Receive AIO messages.

    Receives AIO messages based on the parameters passed into __init__().
    Filters messages based on message type, source, and sequence number.  If
    a timeout was provided, Recv() will block for the specified time
    or until a suitable message is received.  Otherwise Recv() blocks
    indefinitely.  The timeout applies to each socket read, and the overall
    deadline is additionally checked before each read.

    Args:
      accept_invalid: Return raw messages with an invalid AIO version.

    Returns:
      A tuple, (source_addr, aio_header.AioHeader, <message>), of a received
      message that passed all of the specified filters.  If a suitable class
      was not found for the specified message type (e.g. kMessageTypeStdio),
      Recv() returns a raw string in its place.

    Raises:
      socket.timeout: A call to Recv() exceeded the specified timeout.
    """
    # We assume that either ControlTelemetry or SimTelemetry will be the
    # largest message we need to support.  The bound is the same for every
    # datagram, so it is computed once instead of on every loop iteration.
    max_size = max(aio_header.PACK_AIOHEADER_SIZE
                   + max(pack_control_telemetry.PACK_CONTROLTELEMETRY_SIZE,
                         pack_sim_telemetry.PACK_SIMTELEMETRY_SIZE),
                   network_config.MAX_UDP_PAYLOAD_SIZE)

    start_time = time.time()
    while True:
      if self._timeout and time.time() > start_time + self._timeout:
        raise socket.timeout

      received, (source_addr, _) = self._sock.recvfrom(max_size)

      # Ensure sufficient data to unpack header.
      if len(received) < aio_header.PACK_AIOHEADER_SIZE:
        continue

      header = c_helpers.Unpack(received[:aio_header.PACK_AIOHEADER_SIZE],
                                aio_header.AioHeader)
      received = received[aio_header.PACK_AIOHEADER_SIZE:]

      # Ensure correct AIO_VERSION with exception for printf() outputs.
      valid_message = True
      if (header.version != aio_version.AIO_VERSION
          and header.type != aio_message_type.kMessageTypeStdio):
        if accept_invalid:
          valid_message = False
        else:
          continue

      # Ensure correct message type.
      if header.type not in self._message_structs:
        continue

      # Filter on source.
      if self._allowed_sources and header.source not in self._allowed_sources:
        continue

      # Return invalid messages if requested.  These bypass de-duplication
      # and are returned as raw bytes.
      if accept_invalid and not valid_message:
        return (source_addr, header, received)

      # AIO de-duplication.
      cur_time = time.time()
      if self.IsDuplicate(header, payload_string=received, cur_time=cur_time):
        continue
      self._recv_seq_nums[(header.source, header.type)] = header.sequence
      self._recv_times[(header.source, header.type)] = cur_time

      # Unpack data, or return raw bytes if there is no corresponding struct.
      if self._message_structs[header.type]:
        try:
          payload = c_helpers.Unpack(received,
                                     self._message_structs[header.type])
        except c_helpers.UnpackError:
          # Skip payloads that cannot be unpacked and wait for the next one.
          continue
        return (source_addr, header, payload)
      else:
        # Default to returning raw bytes if no suitable structure is found.
        return (source_addr, header, received)

  def _SendInternal(self, message, message_type_name, source, address):
    """Pack (if possible) and send a message.

    Args:
      message: The message to send.  If a struct is known for the message
          type, the message must be an instance of it.
      message_type_name: A MessageType enum name, e.g.
          'kMessageTypeMotorStatus'.
      source: An AioNode enum name to use as the sender of the message.
      address: An (ip_address, port) tuple to send the datagram to.

    Raises:
      AioClientException: The message is not an instance of the struct
          associated with message_type_name.
      KeyError: The message type was not among the message_types passed to
          __init__().
    """
    type_enum = message_type_helper.Value(message_type_name)

    if (self._message_structs[type_enum]
        and not isinstance(message, self._message_structs[type_enum])):
      raise AioClientException('Actual type (%s) disagrees with enum name (%s).'
                               % (type(message).__name__, message_type_name))

    header = c_helpers.Pack(self._BuildHeader(aio_node_helper.Value(source),
                                              type_enum))
    assert header

    # If the message cannot be packed, it is sent as-is after the header.
    # NOTE: buffer() is a Python 2 builtin, so this method requires Python 2
    # as written.
    packed = c_helpers.Pack(message)
    if packed:
      contents = buffer(header) + buffer(packed)
    else:
      contents = buffer(header) + buffer(message)

    self._sock.sendto(contents, address)

  def Send(self, message, message_type_name, source):
    """Send multicast AIO message.

    Args:
      message: A message of the type specified by message_type_name.
      message_type_name: A MessageType enum name of the AIO message type to
          be sent,  e.g. 'kMessageTypeMotorStatus'.
      source: An AioNode enum name specifying the AIO node to spoof as the
          sender of the message, e.g. 'kAioNodeMotorSbo'.

    Raises:
      AioClientException: The message is not an instance of the struct
          associated with message_type_name.
      KeyError: The message type was not among the message_types passed to
          __init__(), so no multicast group is known for it.
    """
    type_enum = message_type_helper.Value(message_type_name)
    address = (self._mcast_groups[type_enum], network_config.UDP_PORT_AIO)
    self._SendInternal(message, message_type_name, source, address)

  def Close(self):
    """Close UDP socket associated with AioClient instance."""
    self._sock.close()
